import numpy as np
import pickle

# Load the pickled dataset and split its records into three parallel arrays.
print("load data")
file = 'mix.data'
print(file)
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open(file, "rb") as handle:
    vqa_data = pickle.load(handle)

# Every record is unpacked as a 3-tuple. Presumably (visual features,
# text features, class label) -- TODO confirm against the data producer.
vs, ts, label = [], [], []
for visual, text, target in vqa_data:
    vs.append(visual)
    ts.append(text)
    label.append(target)

vs, ts, label = np.array(vs), np.array(ts), np.array(label)
print("load data success!")

import torch
import torch.nn as nn

class Model(nn.Module):
    """Two-layer MLP classifier over concatenated ``v`` and ``t`` features.

    Args:
        in_dim: Width of the concatenated ``[v, t]`` feature vector.
        hidden_dim: Width of the single hidden layer.
        num_classes: Number of output logits.

    The defaults reproduce the original fixed 300 -> 200 -> 3 architecture.
    """

    def __init__(self, in_dim=300, hidden_dim=200, num_classes=3):
        super(Model, self).__init__()
        # Original author's note on this architecture: "74"
        # (presumably an accuracy figure -- unverified).
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, v, t):
        """Return logits for a batch.

        ``v`` and ``t`` are concatenated along dim 1, so they must agree on
        every other dimension and their dim-1 sizes must sum to ``in_dim``.
        """
        fused = torch.cat([v, t], dim=1)
        return self.backbone(fused)

# --- Multimodal model setup, tensor conversion and train/test split ---
print("prepera model and optimizer")
best_mm = 0  # referenced by the summary print at the end of the file
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, weight_decay=1e-4)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1000, gamma = 0.1)

# Feature arrays become float tensors; labels keep the dtype numpy gave them.
vs, ts = (torch.from_numpy(arr).float() for arr in (vs, ts))
label = torch.from_numpy(label)

# The first n_train samples are used for training, the remainder for testing.
n_train = 6000
train_vs, test_vs = vs[:n_train], vs[n_train:]
train_ts, test_ts = ts[:n_train], ts[n_train:]
train_label, test_label = label[:n_train], label[n_train:]
print(train_vs.shape, train_ts.shape, train_label.shape)

# for i in range(2000):
#     output = model(train_vs, train_ts)
#     loss = torch.nn.CrossEntropyLoss()(output, train_label)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     #scheduler.step()
#     if i%100 == 0:
#         _, pred = output.max(1)
#         train_correct = (pred == train_label).sum().item()
#         print("train acc:", train_correct/len(train_label))
#         with torch.no_grad():
#             test_output = model(test_vs, test_ts)
#             _, pred = test_output.max(1)
#             eval_correct = (pred == test_label).sum().item()
#             print("test acc:", eval_correct/len(test_label))
#             if eval_correct/len(test_label) > best_mm:
#                 best_mm = eval_correct/len(test_label)




class VModel(nn.Module):
    """Two-layer MLP classifier over a single feature vector.

    In this file it is trained on ``train_vs`` only.

    Args:
        in_dim: Width of the input feature vector.
        hidden_dim: Width of the single hidden layer.
        num_classes: Number of output logits.

    The defaults reproduce the original fixed 200 -> 100 -> 3 architecture.
    """

    def __init__(self, in_dim=200, hidden_dim=100, num_classes=3):
        super(VModel, self).__init__()
        # Original author's note on this architecture: "74"
        # (presumably an accuracy figure -- unverified).
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

        # Echo the architecture at construction time.
        print(self.backbone)

    def forward(self, v):
        """Return logits for a batch ``v`` whose last dimension is ``in_dim``."""
        return self.backbone(v)


# --- Train and periodically evaluate the classifier fed with train_vs ---
model = VModel()
# optimizer = torch.optim.SGD(model.parameters(), lr=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, weight_decay=1e-4)
best_v = 0  # best test accuracy seen at any evaluation step

# Number of test samples per class, keyed by the label rendered as a string.
labels_dist = {"0": 0, "1": 0, '2': 0}
for lbl in test_label:
    labels_dist[str(lbl.item())] += 1
print(labels_dist)

# Per-class count of correct test predictions; reset after every evaluation.
preds_dist = {"0": 0, "1": 0, '2': 0}

# The loss module holds no state, so build it once instead of every step.
criterion = torch.nn.CrossEntropyLoss()

for i in range(1000):
    # Full-batch gradient step on the whole training split.
    output = model(train_vs)
    loss = criterion(output, train_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        # `output` was computed before this step's update, so the train
        # accuracy reflects the parameters as they were at the start of step i.
        _, pred = output.max(1)
        train_correct = (pred == train_label).sum().item()
        # Normalise by the actual training-set size so the reported figure
        # stays correct however the train/test split is sized.
        print("v train acc:", train_correct / len(train_label))
        with torch.no_grad():
            test_output = model(test_vs)
            _, pred = test_output.max(1)
            eval_correct = (pred == test_label).sum().item()
            for j in range(len(pred)):
                if pred[j] == test_label[j]:
                    preds_dist[str(test_label[j].item())] += 1
            # Turn the per-class hit counts into per-class accuracies in place.
            for key in ("0", "1", "2"):
                preds_dist[key] = preds_dist[key] / labels_dist[key]

            test_acc = eval_correct / len(test_label)
            print("v test acc:", test_acc)
            best_v = max(best_v, test_acc)
            print(preds_dist)
            preds_dist = {"0": 0, "1": 0, '2': 0}

# --- Final accuracy and 3x3 confusion matrix, then stop the script ---
# NOTE(review): `pred` still holds the test predictions from the last periodic
# evaluation inside the training loop, not a fresh pass through the model
# after its final update -- confirm this is intended.
eval_correct = (pred == test_label).sum().item()
print(eval_correct / len(test_label))

# Rows are indexed by the true label, columns by the predicted label.
confusion_matrix = [[0] * 3 for _ in range(3)]
for truth, guess in zip(test_label.tolist(), pred.tolist()):
    confusion_matrix[truth][guess] += 1
print(confusion_matrix)

# Everything below this point in the file is skipped at runtime.
import sys
sys.exit()
class TModel(nn.Module):
    """Two-layer MLP classifier over a single feature vector.

    In this file it is fed ``train_ts`` / ``test_ts``.

    Args:
        in_dim: Width of the input feature vector.
        hidden_dim: Width of the single hidden layer.
        num_classes: Number of output logits.

    The defaults reproduce the original fixed 100 -> 100 -> 3 architecture.
    """

    def __init__(self, in_dim=100, hidden_dim=100, num_classes=3):
        super(TModel, self).__init__()
        # Original author's note on this architecture: "74"
        # (presumably an accuracy figure -- unverified).
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

        # Echo the architecture at construction time.
        print(self.backbone)

    def forward(self, v):
        """Return logits for a batch ``v`` whose last dimension is ``in_dim``.

        The parameter keeps the name ``v`` so existing keyword callers still
        work.
        """
        return self.backbone(v)


# --- Train and periodically evaluate the classifier fed with train_ts ---
model = TModel()
# optimizer = torch.optim.SGD(model.parameters(), lr=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, weight_decay=1e-4)
best_t = 0  # best test accuracy seen at any evaluation step

loss_fn = torch.nn.CrossEntropyLoss()  # holds no state, so one instance is enough

for step in range(1000):
    # Full-batch gradient step on the whole training split.
    output = model(train_ts)
    loss = loss_fn(output, train_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 != 0:
        continue

    # Periodic report: train accuracy from this step's forward pass,
    # then a no-grad pass over the test split.
    _, pred = output.max(1)
    train_correct = (pred == train_label).sum().item()
    print("t train acc:", train_correct / len(train_label))
    with torch.no_grad():
        test_output = model(test_ts)
        _, pred = test_output.max(1)
        eval_correct = (pred == test_label).sum().item()

        # Count correct predictions per true class.
        for guess, truth in zip(pred.tolist(), test_label.tolist()):
            if guess == truth:
                preds_dist[str(truth)] += 1
        # Convert the counts into per-class accuracies in place.
        for key in ("0", "1", "2"):
            preds_dist[key] = preds_dist[key] / labels_dist[key]

        test_acc = eval_correct / len(test_label)
        print("t test acc:", test_acc)
        if test_acc > best_t:
            best_t = test_acc
        print(preds_dist)
        preds_dist = {"0": 0, "1": 0, '2': 0}

print("file:{}, best_mm:{}, best_v:{}, best_t:{}".format(file, best_mm, best_v, best_t))

